{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "# Python Version of Tree SHAP\n",
    "\n",
    "This is a sample implementation of Tree SHAP written in Python for easy reading."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-14T15:13:45.176673980Z",
     "start_time": "2023-10-14T15:13:42.734552921Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "import numba\n",
    "import numpy as np\n",
    "import sklearn.ensemble\n",
    "import xgboost\n",
    "\n",
    "import shap"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### Load California dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-14T15:13:45.191494227Z",
     "start_time": "2023-10-14T15:13:45.179770187Z"
    },
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": "(1000, 8)"
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# load 1000 points of the California dataset through shap's dataset helper;\n",
    "# the shape is displayed as a quick sanity check\n",
    "X, y = shap.datasets.california(n_points=1000)\n",
    "X.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### Train sklearn random forest"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-14T15:13:46.873411331Z",
     "start_time": "2023-10-14T15:13:45.219992128Z"
    },
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/html": "<style>#sk-container-id-1 {color: black;}#sk-container-id-1 pre{padding: 0;}#sk-container-id-1 div.sk-toggleable {background-color: white;}#sk-container-id-1 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-1 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-1 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-1 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-1 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-1 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-1 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-1 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-1 div.sk-label:hover 
label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-1 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-1 div.sk-item {position: relative;z-index: 1;}#sk-container-id-1 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-1 div.sk-item::before, #sk-container-id-1 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-1 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-1 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-1 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-1 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-1 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-1 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-1 div.sk-label-container {text-align: center;}#sk-container-id-1 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. 
See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-1 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>RandomForestRegressor(max_depth=4, n_estimators=1000)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" checked><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">RandomForestRegressor</label><div class=\"sk-toggleable__content\"><pre>RandomForestRegressor(max_depth=4, n_estimators=1000)</pre></div></div></div></div></div>",
      "text/plain": "RandomForestRegressor(max_depth=4, n_estimators=1000)"
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 1000 shallow trees (depth at most 4); no random_state is passed, so the fitted\n",
    "# forest (and every number derived from it below) can differ between runs\n",
    "model = sklearn.ensemble.RandomForestRegressor(n_estimators=1000, max_depth=4)\n",
    "model.fit(X, y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### Train XGBoost model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-14T15:13:47.891842198Z",
     "start_time": "2023-10-14T15:13:46.874216599Z"
    },
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "is_sparse is deprecated and will be removed in a future version. Check `isinstance(dtype, pd.SparseDtype)` instead.\n",
      "is_categorical_dtype is deprecated and will be removed in a future version. Use isinstance(dtype, CategoricalDtype) instead\n",
      "is_categorical_dtype is deprecated and will be removed in a future version. Use isinstance(dtype, CategoricalDtype) instead\n",
      "is_categorical_dtype is deprecated and will be removed in a future version. Use isinstance(dtype, CategoricalDtype) instead\n"
     ]
    }
   ],
   "source": [
    "# boosted model used for the native XGBoost Tree SHAP timing further down\n",
    "bst = xgboost.train({\"learning_rate\": 0.01, \"max_depth\": 4}, xgboost.DMatrix(X, label=y), 1000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### Python TreeExplainer\n",
    "\n",
    "This uses numba to speed things up."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-14T15:13:47.899642452Z",
     "start_time": "2023-10-14T15:13:47.898748678Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "class TreeExplainer:\n",
    "    \"\"\"Tree SHAP explainer for a tree ensemble, written in Python and compiled with numba.\n",
    "\n",
    "    Only scikit-learn ``RandomForestRegressor`` models are recognised. The per-tree\n",
    "    SHAP values are summed and then divided by the number of trees.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model, **kwargs):\n",
    "        \"\"\"Convert the trees of ``model`` and preallocate the unique-path buffers.\n",
    "\n",
    "        Args:\n",
    "            model: A fitted ``sklearn.ensemble.RandomForestRegressor``.\n",
    "            **kwargs: Accepted for call compatibility; not used.\n",
    "\n",
    "        Raises:\n",
    "            TypeError: If ``model`` is not a supported model type.\n",
    "        \"\"\"\n",
    "        if str(type(model)).endswith(\"sklearn.ensemble._forest.RandomForestRegressor'>\"):\n",
    "            self.trees = [\n",
    "                Tree(\n",
    "                    children_left=e.tree_.children_left,\n",
    "                    children_right=e.tree_.children_right,\n",
    "                    # x_missing is always all-False in shap_values, so the default\n",
    "                    # branch is never taken; the right child is only a placeholder\n",
    "                    children_default=e.tree_.children_right,\n",
    "                    feature=e.tree_.feature,\n",
    "                    threshold=e.tree_.threshold,\n",
    "                    value=e.tree_.value[:, 0, 0],\n",
    "                    node_sample_weight=e.tree_.weighted_n_node_samples,\n",
    "                )\n",
    "                for e in model.estimators_\n",
    "            ]\n",
    "        else:\n",
    "            # without this branch an unsupported model only failed further down with\n",
    "            # an AttributeError about the missing self.trees\n",
    "            raise TypeError(\"Unsupported model type: \" + str(type(model)))\n",
    "\n",
    "        # Preallocate space for the unique path data. Every recursion level works in\n",
    "        # its own slice placed after its parent's, so 1 + 2 + ... + maxd entries are needed.\n",
    "        maxd = np.max([t.max_depth for t in self.trees]) + 2\n",
    "        s = (maxd * (maxd + 1)) // 2\n",
    "        self.feature_indexes = np.zeros(s, dtype=np.int32)\n",
    "        self.zero_fractions = np.zeros(s, dtype=np.float64)\n",
    "        self.one_fractions = np.zeros(s, dtype=np.float64)\n",
    "        self.pweights = np.zeros(s, dtype=np.float64)\n",
    "\n",
    "    def shap_values(self, X, **kwargs):\n",
    "        \"\"\"Compute SHAP values for a single instance (1-D) or a batch of instances (2-D).\n",
    "\n",
    "        Args:\n",
    "            X: A numpy array, pandas Series or pandas DataFrame of feature values.\n",
    "            **kwargs: Accepted for call compatibility; not used.\n",
    "\n",
    "        Returns:\n",
    "            An array of shape ``(n_features + 1,)`` or ``(n_instances, n_features + 1)``;\n",
    "            the last entry of each row is the bias term.\n",
    "        \"\"\"\n",
    "        # convert dataframes\n",
    "        if str(type(X)).endswith(\"pandas.core.series.Series'>\"):\n",
    "            X = X.values\n",
    "        elif str(type(X)).endswith(\"'pandas.core.frame.DataFrame'>\"):\n",
    "            X = X.values\n",
    "\n",
    "        assert str(type(X)).endswith(\"'numpy.ndarray'>\"), \"Unknown instance type: \" + str(type(X))\n",
    "        assert len(X.shape) == 1 or len(X.shape) == 2, \"Instance must have 1 or 2 dimensions!\"\n",
    "\n",
    "        # tree_shap_recursive is compiled for float64 inputs only, so integer or\n",
    "        # float32 data has to be converted before it reaches the kernel\n",
    "        X = X.astype(np.float64, copy=False)\n",
    "\n",
    "        # treat a single instance as a batch of one row\n",
    "        single = X.ndim == 1\n",
    "        rows = X.reshape(1, -1) if single else X\n",
    "\n",
    "        phi = np.zeros((rows.shape[0], rows.shape[1] + 1))\n",
    "        # no feature is ever flagged as missing by this explainer\n",
    "        x_missing = np.zeros(rows.shape[1], dtype=bool)\n",
    "        for i in range(rows.shape[0]):\n",
    "            for t in self.trees:\n",
    "                self.tree_shap(t, rows[i, :], x_missing, phi[i, :])\n",
    "\n",
    "        # the contributions of all trees were summed; turn the sum into a mean\n",
    "        phi /= len(self.trees)\n",
    "        return phi[0] if single else phi\n",
    "\n",
    "    def tree_shap(self, tree, x, x_missing, phi, condition=0, condition_feature=0):\n",
    "        \"\"\"Add the SHAP values of one tree for instance ``x`` into ``phi`` (in place).\n",
    "\n",
    "        Args:\n",
    "            tree: The ``Tree`` to explain.\n",
    "            x: 1-D float64 array of feature values.\n",
    "            x_missing: 1-D boolean array flagging missing features of ``x``.\n",
    "            phi: 1-D float64 output array; its last entry accumulates the bias term.\n",
    "            condition: Forwarded to ``tree_shap_recursive``; the bias term is only\n",
    "                updated when this is 0.\n",
    "            condition_feature: Forwarded to ``tree_shap_recursive``.\n",
    "        \"\"\"\n",
    "        # update the bias term, which is the last index in phi\n",
    "        # (note the paper has this as phi_0 instead of phi_M)\n",
    "        if condition == 0:\n",
    "            phi[-1] += tree.values[0]\n",
    "\n",
    "        # start the recursive algorithm\n",
    "        tree_shap_recursive(\n",
    "            tree.children_left,\n",
    "            tree.children_right,\n",
    "            tree.children_default,\n",
    "            tree.features,\n",
    "            tree.thresholds,\n",
    "            tree.values,\n",
    "            tree.node_sample_weight,\n",
    "            x,\n",
    "            x_missing,\n",
    "            phi,\n",
    "            0,  # node_index: start at the root\n",
    "            0,  # unique_depth\n",
    "            self.feature_indexes,\n",
    "            self.zero_fractions,\n",
    "            self.one_fractions,\n",
    "            self.pweights,\n",
    "            1,  # parent_zero_fraction\n",
    "            1,  # parent_one_fraction\n",
    "            -1,  # parent_feature_index: the root has no parent split\n",
    "            condition,\n",
    "            condition_feature,\n",
    "            1,  # condition_fraction\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-14T15:13:51.276386697Z",
     "start_time": "2023-10-14T15:13:47.916741034Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# extend our decision path with a fraction of one and zero extensions\n",
    "@numba.jit(\n",
    "    numba.types.void(\n",
    "        numba.types.int32[:],\n",
    "        numba.types.float64[:],\n",
    "        numba.types.float64[:],\n",
    "        numba.types.float64[:],\n",
    "        numba.types.int32,\n",
    "        numba.types.float64,\n",
    "        numba.types.float64,\n",
    "        numba.types.int32,\n",
    "    ),\n",
    "    nopython=True,\n",
    "    nogil=True,\n",
    ")\n",
    "def extend_path(\n",
    "    feature_indexes,\n",
    "    zero_fractions,\n",
    "    one_fractions,\n",
    "    pweights,\n",
    "    unique_depth,\n",
    "    zero_fraction,\n",
    "    one_fraction,\n",
    "    feature_index,\n",
    "):\n",
    "    \"\"\"Append one element to the path at position ``unique_depth`` and update ``pweights`` in place.\"\"\"\n",
    "    # store the new element in the slot just past the current path\n",
    "    feature_indexes[unique_depth] = feature_index\n",
    "    zero_fractions[unique_depth] = zero_fraction\n",
    "    one_fractions[unique_depth] = one_fraction\n",
    "    # the very first element starts with weight 1, every later one with weight 0\n",
    "    pweights[unique_depth] = 1 if unique_depth == 0 else 0\n",
    "\n",
    "    # walk backwards so that each old weight is read before it is overwritten\n",
    "    path_len = unique_depth + 1\n",
    "    for i in range(unique_depth - 1, -1, -1):\n",
    "        prev = pweights[i]\n",
    "        pweights[i + 1] += one_fraction * prev * (i + 1) / path_len\n",
    "        pweights[i] = zero_fraction * prev * (unique_depth - i) / path_len\n",
    "\n",
    "\n",
    "# undo a previous extension of the decision path\n",
    "@numba.jit(\n",
    "    numba.types.void(\n",
    "        numba.types.int32[:],\n",
    "        numba.types.float64[:],\n",
    "        numba.types.float64[:],\n",
    "        numba.types.float64[:],\n",
    "        numba.types.int32,\n",
    "        numba.types.int32,\n",
    "    ),\n",
    "    nopython=True,\n",
    "    nogil=True,\n",
    ")\n",
    "def unwind_path(feature_indexes, zero_fractions, one_fractions, pweights, unique_depth, path_index):\n",
    "    \"\"\"Remove the element at ``path_index`` from the path, modifying all four arrays in place.\n",
    "\n",
    "    ``unique_depth`` is the index of the last path element before the removal. The\n",
    "    weights ``pweights[0 .. unique_depth - 1]`` are recomputed and the feature/fraction\n",
    "    entries after ``path_index`` are shifted down by one slot. Nothing is returned and\n",
    "    ``unique_depth`` itself is not changed here.\n",
    "    \"\"\"\n",
    "    one_fraction = one_fractions[path_index]\n",
    "    zero_fraction = zero_fractions[path_index]\n",
    "    next_one_portion = pweights[unique_depth]\n",
    "\n",
    "    # recompute the weights from the end of the path towards the front\n",
    "    for i in range(unique_depth - 1, -1, -1):\n",
    "        if one_fraction != 0:\n",
    "            # keep the old weight: it is needed for the next (lower) index\n",
    "            tmp = pweights[i]\n",
    "            pweights[i] = next_one_portion * (unique_depth + 1) / ((i + 1) * one_fraction)\n",
    "            next_one_portion = tmp - pweights[i] * zero_fraction * (unique_depth - i) / (unique_depth + 1)\n",
    "        else:\n",
    "            # with a zero one_fraction only the zero_fraction factor has to be divided out\n",
    "            pweights[i] = (pweights[i] * (unique_depth + 1)) / (zero_fraction * (unique_depth - i))\n",
    "\n",
    "    # close the gap left by the removed element (pweights is not shifted)\n",
    "    for i in range(path_index, unique_depth):\n",
    "        feature_indexes[i] = feature_indexes[i + 1]\n",
    "        zero_fractions[i] = zero_fractions[i + 1]\n",
    "        one_fractions[i] = one_fractions[i + 1]\n",
    "\n",
    "\n",
    "# determine what the total permutation weight would be if\n",
    "# we unwound a previous extension in the decision path\n",
    "@numba.jit(\n",
    "    numba.types.float64(\n",
    "        numba.types.int32[:],\n",
    "        numba.types.float64[:],\n",
    "        numba.types.float64[:],\n",
    "        numba.types.float64[:],\n",
    "        numba.types.int32,\n",
    "        numba.types.int32,\n",
    "    ),\n",
    "    nopython=True,\n",
    "    nogil=True,\n",
    ")\n",
    "def unwound_path_sum(feature_indexes, zero_fractions, one_fractions, pweights, unique_depth, path_index):\n",
    "    \"\"\"Return the sum of the weights that removing the element at ``path_index`` would leave.\n",
    "\n",
    "    None of the arrays is modified; ``feature_indexes`` is not read.\n",
    "    \"\"\"\n",
    "    one_fraction = one_fractions[path_index]\n",
    "    zero_fraction = zero_fractions[path_index]\n",
    "    path_len = unique_depth + 1\n",
    "    total = 0.0\n",
    "\n",
    "    # one_fraction never changes inside the loop, so the case split is made once up front\n",
    "    if one_fraction != 0:\n",
    "        next_one_portion = pweights[unique_depth]\n",
    "        for i in range(unique_depth - 1, -1, -1):\n",
    "            tmp = next_one_portion * path_len / ((i + 1) * one_fraction)\n",
    "            total += tmp\n",
    "            next_one_portion = pweights[i] - tmp * zero_fraction * ((unique_depth - i) / path_len)\n",
    "    else:\n",
    "        for i in range(unique_depth - 1, -1, -1):\n",
    "            total += (pweights[i] / zero_fraction) / ((unique_depth - i) / path_len)\n",
    "\n",
    "    return total\n",
    "\n",
    "\n",
    "class Tree:\n",
    "    \"\"\"Flat-array description of one decision tree in the form the numba kernels expect.\n",
    "\n",
    "    Node ``i`` is described by entry ``i`` of every array. ``max_depth`` is the value\n",
    "    returned by ``compute_expectations`` for the root node (index 0).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        children_left,\n",
    "        children_right,\n",
    "        children_default,\n",
    "        feature,\n",
    "        threshold,\n",
    "        value,\n",
    "        node_sample_weight,\n",
    "    ):\n",
    "        \"\"\"Store the node arrays with the dtypes the kernels are compiled for.\n",
    "\n",
    "        Args:\n",
    "            children_left: Index of each node's left child.\n",
    "            children_right: Index of each node's right child; -1 marks a leaf.\n",
    "            children_default: Index of the child followed when the split feature is missing.\n",
    "            feature: Index of the feature each node splits on.\n",
    "            threshold: Split threshold of each node.\n",
    "            value: Value of each node; only leaf values are used as given, internal\n",
    "                values are recomputed.\n",
    "            node_sample_weight: Weight of the samples reaching each node.\n",
    "        \"\"\"\n",
    "        self.children_left = children_left.astype(np.int32)\n",
    "        self.children_right = children_right.astype(np.int32)\n",
    "        self.children_default = children_default.astype(np.int32)\n",
    "        self.features = feature.astype(np.int32)\n",
    "        # tree_shap_recursive is declared with float64 arrays\n",
    "        self.thresholds = np.asarray(threshold, dtype=np.float64)\n",
    "        # compute_expectations overwrites internal-node values in place, so work on a\n",
    "        # private copy instead of the caller's array (which may be a view into a fitted model)\n",
    "        self.values = np.array(value, dtype=np.float64)\n",
    "        self.node_sample_weight = np.asarray(node_sample_weight, dtype=np.float64)\n",
    "\n",
    "        # fills in the internal-node values of self.values and yields the tree depth\n",
    "        self.max_depth = compute_expectations(\n",
    "            self.children_left,\n",
    "            self.children_right,\n",
    "            self.node_sample_weight,\n",
    "            self.values,\n",
    "            0,\n",
    "        )\n",
    "\n",
    "\n",
    "@numba.jit(nopython=True)\n",
    "def compute_expectations(children_left, children_right, node_sample_weight, values, i, depth=0):\n",
    "    \"\"\"Set the value of every internal node below ``i`` and return the depth of that subtree.\n",
    "\n",
    "    An internal node's value becomes the average of its two children's values weighted\n",
    "    by ``node_sample_weight``. ``values`` is modified in place; leaf values are left as\n",
    "    they are. ``depth`` is passed down the recursion but does not affect the result.\n",
    "    \"\"\"\n",
    "    right = children_right[i]\n",
    "    # a node without a right child is a leaf: nothing to compute, depth 0\n",
    "    if right == -1:\n",
    "        return 0\n",
    "\n",
    "    left = children_left[i]\n",
    "    # children first, so their values are final before they are averaged\n",
    "    left_depth = compute_expectations(children_left, children_right, node_sample_weight, values, left, depth + 1)\n",
    "    right_depth = compute_expectations(children_left, children_right, node_sample_weight, values, right, depth + 1)\n",
    "\n",
    "    w_left = node_sample_weight[left]\n",
    "    w_right = node_sample_weight[right]\n",
    "    values[i] = (w_left * values[left] + w_right * values[right]) / (w_left + w_right)\n",
    "    return max(left_depth, right_depth) + 1\n",
    "\n",
    "\n",
    "# recursive computation of SHAP values for a decision tree\n",
    "@numba.jit(\n",
    "    numba.types.void(\n",
    "        numba.types.int32[:],\n",
    "        numba.types.int32[:],\n",
    "        numba.types.int32[:],\n",
    "        numba.types.int32[:],\n",
    "        numba.types.float64[:],\n",
    "        numba.types.float64[:],\n",
    "        numba.types.float64[:],\n",
    "        numba.types.float64[:],\n",
    "        numba.types.boolean[:],\n",
    "        numba.types.float64[:],\n",
    "        numba.types.int64,\n",
    "        numba.types.int64,\n",
    "        numba.types.int32[:],\n",
    "        numba.types.float64[:],\n",
    "        numba.types.float64[:],\n",
    "        numba.types.float64[:],\n",
    "        numba.types.float64,\n",
    "        numba.types.float64,\n",
    "        numba.types.int64,\n",
    "        numba.types.int64,\n",
    "        numba.types.int64,\n",
    "        numba.types.float64,\n",
    "    ),\n",
    "    nopython=True,\n",
    "    nogil=True,\n",
    ")\n",
    "def tree_shap_recursive(\n",
    "    children_left,\n",
    "    children_right,\n",
    "    children_default,\n",
    "    features,\n",
    "    thresholds,\n",
    "    values,\n",
    "    node_sample_weight,\n",
    "    x,\n",
    "    x_missing,\n",
    "    phi,\n",
    "    node_index,\n",
    "    unique_depth,\n",
    "    parent_feature_indexes,\n",
    "    parent_zero_fractions,\n",
    "    parent_one_fractions,\n",
    "    parent_pweights,\n",
    "    parent_zero_fraction,\n",
    "    parent_one_fraction,\n",
    "    parent_feature_index,\n",
    "    condition,\n",
    "    condition_feature,\n",
    "    condition_fraction,\n",
    "):\n",
    "    \"\"\"Visit ``node_index`` and everything below it, accumulating SHAP values into ``phi``.\n",
    "\n",
    "    The first seven arrays describe the tree (one entry per node), ``x`` is the\n",
    "    instance being explained and ``x_missing`` flags its missing features. The four\n",
    "    ``parent_*`` arrays hold the parent's path; ``parent_zero_fraction``,\n",
    "    ``parent_one_fraction`` and ``parent_feature_index`` describe the split that led\n",
    "    to this node and are appended to the path here.\n",
    "\n",
    "    At a leaf, a contribution is added to ``phi`` for every path element except the\n",
    "    first. At an internal node the function recurses into both children.\n",
    "\n",
    "    With ``condition == 0`` the path is always extended. With a non-zero ``condition``\n",
    "    the split on ``condition_feature`` is kept out of the path; a positive value\n",
    "    follows only the branch ``x`` takes at such a split, a negative value scales\n",
    "    ``condition_fraction`` by each branch's share of the node weight.\n",
    "    \"\"\"\n",
    "    # stop if we have no weight coming down to us\n",
    "    if condition_fraction == 0:\n",
    "        return\n",
    "\n",
    "    # extend the unique path\n",
    "    # this call's path lives in the slice right after the parent's unique_depth + 1\n",
    "    # entries and starts out as a copy of them, so the parent's path stays intact\n",
    "    feature_indexes = parent_feature_indexes[unique_depth + 1 :]\n",
    "    feature_indexes[: unique_depth + 1] = parent_feature_indexes[: unique_depth + 1]\n",
    "    zero_fractions = parent_zero_fractions[unique_depth + 1 :]\n",
    "    zero_fractions[: unique_depth + 1] = parent_zero_fractions[: unique_depth + 1]\n",
    "    one_fractions = parent_one_fractions[unique_depth + 1 :]\n",
    "    one_fractions[: unique_depth + 1] = parent_one_fractions[: unique_depth + 1]\n",
    "    pweights = parent_pweights[unique_depth + 1 :]\n",
    "    pweights[: unique_depth + 1] = parent_pweights[: unique_depth + 1]\n",
    "\n",
    "    # the split on the conditioned feature is deliberately left out of the path\n",
    "    if condition == 0 or condition_feature != parent_feature_index:\n",
    "        extend_path(\n",
    "            feature_indexes,\n",
    "            zero_fractions,\n",
    "            one_fractions,\n",
    "            pweights,\n",
    "            unique_depth,\n",
    "            parent_zero_fraction,\n",
    "            parent_one_fraction,\n",
    "            parent_feature_index,\n",
    "        )\n",
    "\n",
    "    split_index = features[node_index]\n",
    "\n",
    "    # leaf node\n",
    "    if children_right[node_index] == -1:\n",
    "        # index 0 is skipped: it is the element appended by the root call, whose\n",
    "        # feature index is the -1 passed in by TreeExplainer.tree_shap\n",
    "        for i in range(1, unique_depth + 1):\n",
    "            w = unwound_path_sum(\n",
    "                feature_indexes,\n",
    "                zero_fractions,\n",
    "                one_fractions,\n",
    "                pweights,\n",
    "                unique_depth,\n",
    "                i,\n",
    "            )\n",
    "            phi[feature_indexes[i]] += (\n",
    "                w * (one_fractions[i] - zero_fractions[i]) * values[node_index] * condition_fraction\n",
    "            )\n",
    "\n",
    "    # internal node\n",
    "    else:\n",
    "        # find which branch is \"hot\" (meaning x would follow it)\n",
    "        hot_index = 0\n",
    "        cleft = children_left[node_index]\n",
    "        cright = children_right[node_index]\n",
    "        if x_missing[split_index] == 1:\n",
    "            hot_index = children_default[node_index]\n",
    "        elif x[split_index] < thresholds[node_index]:\n",
    "            hot_index = cleft\n",
    "        else:\n",
    "            hot_index = cright\n",
    "        cold_index = cright if hot_index == cleft else cleft\n",
    "        # share of this node's sample weight that goes down each branch\n",
    "        w = node_sample_weight[node_index]\n",
    "        hot_zero_fraction = node_sample_weight[hot_index] / w\n",
    "        cold_zero_fraction = node_sample_weight[cold_index] / w\n",
    "        incoming_zero_fraction = 1\n",
    "        incoming_one_fraction = 1\n",
    "\n",
    "        # see if we have already split on this feature,\n",
    "        # if so we undo that split so we can redo it for this node\n",
    "        path_index = 0\n",
    "        while path_index <= unique_depth:\n",
    "            if feature_indexes[path_index] == split_index:\n",
    "                break\n",
    "            path_index += 1\n",
    "\n",
    "        # path_index == unique_depth + 1 means the search above found no match\n",
    "        if path_index != unique_depth + 1:\n",
    "            # carry the earlier split's fractions into the new path element\n",
    "            incoming_zero_fraction = zero_fractions[path_index]\n",
    "            incoming_one_fraction = one_fractions[path_index]\n",
    "            unwind_path(\n",
    "                feature_indexes,\n",
    "                zero_fractions,\n",
    "                one_fractions,\n",
    "                pweights,\n",
    "                unique_depth,\n",
    "                path_index,\n",
    "            )\n",
    "            unique_depth -= 1\n",
    "\n",
    "        # divide up the condition_fraction among the recursive calls\n",
    "        hot_condition_fraction = condition_fraction\n",
    "        cold_condition_fraction = condition_fraction\n",
    "        if condition > 0 and split_index == condition_feature:\n",
    "            # the cold branch gets no weight, so its recursive call returns immediately\n",
    "            cold_condition_fraction = 0\n",
    "            unique_depth -= 1\n",
    "        elif condition < 0 and split_index == condition_feature:\n",
    "            hot_condition_fraction *= hot_zero_fraction\n",
    "            cold_condition_fraction *= cold_zero_fraction\n",
    "            unique_depth -= 1\n",
    "\n",
    "        # hot branch: x follows it, so the incoming one fraction is passed on\n",
    "        tree_shap_recursive(\n",
    "            children_left,\n",
    "            children_right,\n",
    "            children_default,\n",
    "            features,\n",
    "            thresholds,\n",
    "            values,\n",
    "            node_sample_weight,\n",
    "            x,\n",
    "            x_missing,\n",
    "            phi,\n",
    "            hot_index,\n",
    "            unique_depth + 1,\n",
    "            feature_indexes,\n",
    "            zero_fractions,\n",
    "            one_fractions,\n",
    "            pweights,\n",
    "            hot_zero_fraction * incoming_zero_fraction,\n",
    "            incoming_one_fraction,\n",
    "            split_index,\n",
    "            condition,\n",
    "            condition_feature,\n",
    "            hot_condition_fraction,\n",
    "        )\n",
    "\n",
    "        # cold branch: x does not follow it, so its one fraction is 0\n",
    "        tree_shap_recursive(\n",
    "            children_left,\n",
    "            children_right,\n",
    "            children_default,\n",
    "            features,\n",
    "            thresholds,\n",
    "            values,\n",
    "            node_sample_weight,\n",
    "            x,\n",
    "            x_missing,\n",
    "            phi,\n",
    "            cold_index,\n",
    "            unique_depth + 1,\n",
    "            feature_indexes,\n",
    "            zero_fractions,\n",
    "            one_fractions,\n",
    "            pweights,\n",
    "            cold_zero_fraction * incoming_zero_fraction,\n",
    "            0,\n",
    "            split_index,\n",
    "            condition,\n",
    "            condition_feature,\n",
    "            cold_condition_fraction,\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### Compare runtime of XGBoost Tree SHAP..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-14T15:13:51.429322880Z",
     "start_time": "2023-10-14T15:13:51.285956536Z"
    },
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.1391134262084961\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "is_sparse is deprecated and will be removed in a future version. Check `isinstance(dtype, pd.SparseDtype)` instead.\n",
      "is_categorical_dtype is deprecated and will be removed in a future version. Use isinstance(dtype, CategoricalDtype) instead\n",
      "is_categorical_dtype is deprecated and will be removed in a future version. Use isinstance(dtype, CategoricalDtype) instead\n",
      "is_categorical_dtype is deprecated and will be removed in a future version. Use isinstance(dtype, CategoricalDtype) instead\n"
     ]
    }
   ],
   "source": [
    "# perf_counter is a monotonic, high-resolution clock meant for measuring intervals,\n",
    "# unlike time.time(), which follows the (adjustable) system clock\n",
    "start = time.perf_counter()\n",
    "shap_values = bst.predict(xgboost.DMatrix(X), pred_contribs=True)\n",
    "print(time.perf_counter() - start)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### Versus the Python (numba) Tree SHAP..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-14T15:13:51.678083888Z",
     "start_time": "2023-10-14T15:13:51.429216249Z"
    },
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": "array([-0.5140118 ,  0.07819144,  0.09588052, -0.01004643,  0.21487504,\n        0.37846563,  0.15740923, -0.04002657,  2.04526993])"
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# explain one synthetic instance with every feature set to 1;\n",
    "# the last entry of the result is the bias term\n",
    "x = np.ones(X.shape[1])\n",
    "TreeExplainer(model).shap_values(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-10-14T15:14:01.899083922Z",
     "start_time": "2023-10-14T15:13:51.678228170Z"
    },
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0075643062591552734\n",
      "10.171732664108276\n"
     ]
    }
   ],
   "source": [
    "# time the construction of the explainer (tree conversion) ...\n",
    "start = time.perf_counter()\n",
    "ex = TreeExplainer(model)\n",
    "print(time.perf_counter() - start)\n",
    "\n",
    "# ... and, separately, explaining every row of X\n",
    "start = time.perf_counter()\n",
    "ex.shap_values(X)\n",
    "print(time.perf_counter() - start)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
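    "### ...about a hundred times slower in Python, even with numba\n",
    "\n",
    "Note that the two timings do not explain the same model: the XGBoost timing is for the boosted model `bst`, while the Python timing is for the random forest `model`. Both were trained with 1000 trees of depth at most 4, so the comparison is only a rough one."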
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
